基于TensorFlow lite 在移动端部署模型

您所在的位置:网站首页 tensorflow mobile 基于TensorFlow lite 在移动端部署模型

基于TensorFlow lite 在移动端部署模型

#基于TensorFlow lite 在移动端部署模型| 来源: 网络整理| 查看: 265

1. 移动端深度学习框架比较

目前,移动端模型嵌入SDK 主要有 :Google , Tensor Flow Mobile(19年不再维护)是在推出Tensor flow(简称TF)时同时推出的 Tensor Flow Mobile和2017年11月14日 Google I/O 2017大会上的推出的Tensor Flow Lite(开发者预览版,简称 TF Lite);百度, 2017年9月25日发布MDL框架;腾讯,2017年7月NCNN框架正式开源和2018年5月8日开源FeatherCNN框架,小米,2018年6月28日发布框架MACE (Mobile AI Compute Engine )。

1.1 开源框架对比

部门开源框架对比

2.TensorFlow Lite

Tensor Flow Lite 是Google I/O 2017大会上的推出的,是专门针对移动设备上可运行的深度网络模型简单版,目前还只是开发者预览版,未推出正式版。官网:https://tensorflow.google.cn/lite/

2.1 Tensor Flow Lite组件

图 1 TensorFlow Lite 的架构设计

组件包括:

1)Tensor Flow 模型(Tensor Flow Model):保存在磁盘中的训练模型。

2)Tensor Flow Lite 转化器(Tensor Flow Lite Converter):将模型转换成Tensor Flow Lite 文件格式的项目。

3)Tensor Flow Lite 模型文件(Tensor Flow Lite Model File):基于Flat Buffers,适配最大速度和最小规模的模型。

2.1 模型部署

将 TensorFlow Lite 模型文件部署到移动 App 中:

Java API:安卓设备上适用于 C++ API 的便利封装。

C++ API:加载 TensorFlow Lite 模型文件,启动编译器。安卓和 iOS 设备上均有同样的库。

编译器(Interpreter):使用运算符执行模型。解释器支持选择性加载运算符;没有运算符时,编译器只有 70KB,加载所有运算符后,编译器为 300KB。这比 TensorFlow Mobile(具备一整套运算符)的 1.5M 要小得多。

在选择的安卓设备上,编译器将使用安卓神经网络 API 进行硬件加速,或者在无可用 API 的情况下默认执行 CPU。

开发者还使用 C++ API 实现自定义 kernel,它可被解释器使用。

3. Android端模型嵌入

图2 TensorFLow Lite 流程

Andorid端模型嵌入过程如下:

训练TF 模型Freeze TF模型转换TF Lite模型TF Lite模型嵌入Android 端

下图展示模型转换图:

图3 TF模型转换 TF Lite 模型图

3.1 模型训练

由于TF Lite 仅仅支持有限的TF 算子,所以训练过程中,需要调整相关代码,使用TF Lite 支持的算子,防止TF 无法转TFLite 。下面给出TF Lite 主要支持的算子(更全更具体的信息在官网:https://tensorflow.google.cn/lite/guide/ops_compatibility):

许多TensorFlow操作可以由TensorFlow Lite处理,即使它们没有直接等效的。这种情况下的操作可以简单地从图中删除(tf.identity),替换为张量(tf.placeholder),或者融合成更复杂的操作(tf.nn.bias-add)。甚至某些受支持的操作有时也可以通过其中一个过程删除。

3.2 Freeze TF模型

由于Tensor Flow训练模型保存时,网络结构与模型参数分离,因此需要把网络结构与模型参数相结合为一个文件,此过程称为冷冻模型。

命令行工具():

freeze_graph --input_graph=xxx/xxx.pb \ # TF 模型文件 --input_checkpoint=xx/model-xxxxxx.data-00000-of-00001 \ # TF 模型参数 --input_binary=true \ --output_graph=xx/frozen_xxx.pb \ # 输出文件 --output_node_names=output/predictions # 网络中输出节点名

Python :

saver = tf.train.import_meta_graph ('xxx.meta') with tf.Session() as sess: saver.restore(sess, ’xxx.data’) # graph.ckpt.data-00000-of-00001 # 将计算图写入到模型文件中 : 'input_text', 'output/predictions' output_graph_def=graph_util.convert_variables_to_constants(sess, sess.graph_def, ['input_text', 'output/predictions']) model_f = tf.gfile.FastGFile('xxx.pb', mode="wb") model_f.write(output_graph_def.SerializeToString())

3.3 转换TF Lite模型

转换TF Lite 支持如下三种格式:

保存模型: 网络结构文件和模型参数文件以及相关输入输出变量;tf.keras : HDF5文件 包含模型参数与输入输出参数;冷冻 tf.GraphDef

3.3.1 转换保存模型

Python saver = tf.train.import_meta_graph( 'xxx.meta') with tf.Session() as sess: saver.restore(sess,’xxx.data’) # graph.ckpt.data-00000-of-00001 input_x = tf.get_default_graph().get_tensor_by_name("input_x:0") probs = tf.get_default_graph().get_tensor_by_name("output/predictions:0") input_arrays = [input_x, input_y] output_arrays = [probs] converter = tf.contrib.lite.TFLiteConverter.from_session(sess, input_arrays, output_arrays) tflite_model = converter.convert() open( "xxx.tflite", "wb").write(tflite_model)

3.3.2 转换 冷冻 tf.GraphDef

Python : input_arrays = ["input_x", "input_y"] output_arrays = ["output/predictions"] converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(fi_pb, input_arrays, output_arrays) tflite_model = converter.convert() open(‘xxx/xxx.tflite’, "wb").write(tflite_model)

综上,关于TF模型转化为移动端可以嵌入的压缩模型。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3